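// MSM-TVTP, AR(1), 2 regimes, common error scale (sigma, estimated) shared by both regimes
// Regime-specific parameters: outcome intercept (alpha), AR coefficient (phi1, phi2) and
// TVTP slope on z (lambda); the TVTP intercept (gamma) is shared by both regimes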

data {
  // Observed inputs: T aligned (lagged value, current value) pairs plus the
  // covariate that drives the transition probabilities.
  int<lower = 0> T; // length of every series below; per the original note, number of observed time periods - 1
  vector[T] y_head; // lagged outcome: y_head[t] is the AR(1) regressor paired with y_tail[t]
                    // (presumably the series without its last observation -- confirm against the caller)
  vector[T] y_tail; // current outcome modelled at step t
                    // (presumably the series without its first observation -- confirm against the caller)
  vector[T] z; // exogenous random variable affecting transition probabilities at step t
//  real<lower = 0> sigma; // disabled: sigma supplied as fixed data; sigma is declared under parameters instead
}

parameters {
  // Quantities sampled by Stan.
  ordered[2] alpha; // regime-specific intercepts of the AR(1) mean; the ordered type enforces alpha[1] < alpha[2]
  real<lower=0,upper=1> phi1; // AR(1) coefficient in regime 1, restricted to [0, 1]
  real<lower=0,upper=1> phi2; // AR(1) coefficient in regime 2, restricted to [0, 1]
  real gamma; // intercept of the probit index, shared by both transition probabilities (p11 and p12)
  real lambda[2]; // slope on z in the probit index: lambda[1] enters p11, lambda[2] enters p12
  real<lower = 0> sigma; // scale (standard deviation) handed to normal_lpdf; common to both regimes

}
transformed parameters {
  // Time-varying transition probabilities. As used in the prediction step
  // below, pij[t] = P(S[t] = i | S[t - 1] = j, z[t]), so p11 + p21 = 1 and
  // p12 + p22 = 1.
  real p11[T];
  real p12[T];
  real p21[T];
  real p22[T];
  real p_sprev_1_givenprev[T]; // P(S[t - 1] = 1 | omega[t - 1], theta)
  real p_scur_1_givenprev[T];  // P(S[t] = 1 | omega[t - 1], theta)
  real p_scur_1_givencur[T];   // P(S[t] = 1 | omega[t], theta)
  real s1[T]; // log density of y_tail[t] given regime 1
  real s2[T]; // log density of y_tail[t] given regime 2
  real fy[T]; // log predictive density of y_tail[t], regime marginalised out

  // Probit link: each transition probability is the standard normal CDF of a
  // linear index in z[t].
  for (t in 1:T) {
    p11[t] = normal_cdf(gamma + lambda[1] * z[t], 0, 1);
    p21[t] = 1 - p11[t];
    p12[t] = normal_cdf(gamma + lambda[2] * z[t], 0, 1);
    p22[t] = 1 - p12[t];
  }

  // Filtering recursion. The initial condition is handled inside the loop
  // (rather than by indexing element 1 unconditionally) so that T = 0, which
  // the data constraint allows, leaves every array empty instead of raising
  // an index-out-of-range error.
  for (t in 1:T) {
    if (t == 1) {
      // Initial state probability (Piger eq. 14): stationary probability of
      // regime 1 implied by the transition probabilities at t = 1.
      // Can also treat the initial state probability as a parameter to be estimated.
      p_sprev_1_givenprev[t] = (1 - p22[t]) / (2 - p11[t] - p22[t]);
    } else {
      // Carry forward the filtered probability from the previous step.
      p_sprev_1_givenprev[t] = p_scur_1_givencur[t - 1];
    }

    // Piger eq. 10: one-step-ahead probability of regime 1
    p_scur_1_givenprev[t] = p11[t] * p_sprev_1_givenprev[t] +
                            p12[t] * (1 - p_sprev_1_givenprev[t]);

    // Regime-conditional AR(1) log densities
    s1[t] = normal_lpdf(y_tail[t] | alpha[1] + phi1 * y_head[t], sigma);
    s2[t] = normal_lpdf(y_tail[t] | alpha[2] + phi2 * y_head[t], sigma);

    // Piger eq. 11: log of the two-component mixture density
    fy[t] = log_mix(p_scur_1_givenprev[t], s1[t], s2[t]);

    // Piger eq. 13: Bayes update, evaluated on the log scale then exponentiated
    p_scur_1_givencur[t] = exp(s1[t] + log(p_scur_1_givenprev[t]) - fy[t]);
  }

}
model {
  // Priors: standard normal on the regime intercepts and on the probit
  // intercept/slopes; uniform(0, 1) on both AR coefficients.
  // sigma has no sampling statement in this block.
  alpha ~ normal(0, 1);
  gamma ~ normal(0, 1);
  lambda ~ normal(0, 1);
  phi1 ~ uniform(0, 1);
  phi2 ~ uniform(0, 1);

  // Likelihood: add the per-step log predictive densities to the target.
  target += sum(fy);
}




